{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plots for SSLSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "os.chdir('../..')\n",
    "sys.path.insert(1, os.path.join(sys.path[0], '../..'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.getLogger('fontTools').setLevel(logging.ERROR)\n",
    "logging.getLogger('matplotlib').setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebooks.articles.utils import (\n",
    "    evaluate_sv,\n",
    "    plot_inter_class_similarity,\n",
    "    plot_intra_class_similarity,\n",
    ")\n",
    "\n",
    "from notebooks.evaluation.sv_visualization import (\n",
    "    det_curve,\n",
    ")\n",
    "\n",
    "from plotnine import *\n",
    "import patchworklib as pw\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Model:\n",
    "\n",
    "    scores: List[float] = None\n",
    "    targets: List[int] = None\n",
    "    embeddings: Dict[str, torch.Tensor] = None\n",
    "\n",
    "\n",
    "def get_models_for_visualization(scores, names=None):\n",
    "    if names is None:\n",
    "        names = list(scores.keys())\n",
    "\n",
    "    models = {\n",
    "        k:Model(v['scores'], v['targets'])\n",
    "        for k, v\n",
    "        in scores.items()\n",
    "        if k in names\n",
    "    }\n",
    "\n",
    "    return models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_RESNET = {\n",
    "    \"SimCLR\":       \"models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/\",\n",
    "    \"MoCo\":         \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "    \"SwAV\":         \"models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/\",\n",
    "    \"VICReg\":       \"models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/\",\n",
    "    \"DINO\":         \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "    \"Supervised\":   \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/\",\n",
    "}\n",
    "\n",
    "MODELS_ECAPA = {\n",
    "    \"SimCLR\":       \"models/ssl/voxceleb2/simclr/simclr_enc-ECAPATDNN-1024_proj-none_t-0.03/\",\n",
    "    \"MoCo\":         \"models/ssl/voxceleb2/moco/moco_enc-ECAPATDNN-1024_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "    \"SwAV\":         \"models/ssl/voxceleb2/swav/swav_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/\",\n",
    "    \"VICReg\":       \"models/ssl/voxceleb2/vicreg/vicreg_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/\",\n",
    "    \"DINO\":         \"models/ssl/voxceleb2/dino/dino_enc-ECAPATDNN-1024_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "    \"Supervised\":   \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2/\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mizani.palettes import hue_pal\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "palette = hue_pal(h=0.01, l=0.6, s=0.65, color_space=\"hls\")(len(MODELS_ECAPA.keys()))\n",
    "\n",
    "plt.figure(figsize=(8, 1))\n",
    "for i, color in enumerate(palette):\n",
    "    plt.bar(i, 1, color=color)\n",
    "plt.xticks(range(len(MODELS_ECAPA.keys())), MODELS_ECAPA.keys())\n",
    "plt.show()\n",
    "\n",
    "palette = dict(zip(MODELS_ECAPA.keys(), palette))\n",
    "print(palette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_ECAPA_ORDER = list(MODELS_ECAPA.keys())\n",
    "MODELS_ECAPA_PALETTE = palette\n",
    "\n",
    "MODELS_ECAPA_ORDER, MODELS_ECAPA_PALETTE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_ECAPA_PALETTE = {\n",
    "    'SimCLR': '#db5f57',\n",
    "    'MoCo': '#57db5f',\n",
    "    'SwAV': '#d3db57',\n",
    "    'VICReg': '#57d3db',\n",
    "    'DINO': '#5f57db',\n",
    "    'Supervised': '#01041a'\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vox1o_scores = evaluate_sv(MODELS_RESNET, 'embeddings_vox1_avg.pt', trials=[\"voxceleb1_test_O\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vox1o_scores_ecapa = evaluate_sv(MODELS_ECAPA, 'embeddings_vox1_avg.pt', trials=[\"voxceleb1_test_O\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vox1h_scores_ecapa = evaluate_sv(MODELS_ECAPA, 'embeddings_vox1_avg.pt', trials=[\"voxceleb1_test_H\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Complementarity (Correlation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({model:vox1o_scores_ecapa[model][\"scores\"] for model in MODELS_ECAPA.keys()})\n",
    "# df = pd.DataFrame({model:vox1h_scores_ecapa[model][\"scores\"] for model in MODELS_ECAPA.keys()})\n",
    "\n",
    "correlation_matrix = df.corr()\n",
    "\n",
    "corr_long = correlation_matrix.reset_index().melt(id_vars=\"index\")\n",
    "corr_long.columns = [\"x\", \"y\", \"correlation\"]\n",
    "corr_long[\"x\"] = pd.Categorical(corr_long[\"x\"], categories=MODELS_ECAPA_ORDER, ordered=True)\n",
    "corr_long[\"y\"] = pd.Categorical(corr_long[\"y\"], categories=MODELS_ECAPA_ORDER[::-1], ordered=True)\n",
    "corr_long[\"label\"] = corr_long[\"correlation\"].map(\"{:.2f}\".format)\n",
    "\n",
    "p = (\n",
    "    ggplot(corr_long, aes(x='x', y='y', fill='correlation'))\n",
    "    + geom_tile(color='white')\n",
    "    + geom_text(aes(label='label'), size=10)\n",
    "    + scale_fill_gradient(\n",
    "        low='#c2d1ff', high='#4a78ff',\n",
    "        limits=(0.9, 1.0)\n",
    "    )\n",
    "    + labs(x=\"\", y=\"\", fill=\"Correlation\")\n",
    "    + theme_bw()\n",
    "    + theme(\n",
    "        figure_size=(5, 4.9),\n",
    "        text=element_text(size=14),\n",
    "        legend_title=element_blank(),\n",
    "        legend_position=\"none\",\n",
    "        panel_border=element_blank(),\n",
    "        axis_text_x=element_text(rotation=45, hjust=1)\n",
    "    )\n",
    ")\n",
    "\n",
    "# p.save('correlation.pdf')\n",
    "\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fusion (score-level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sslsv.evaluations.CosineSVEvaluation import SpeakerVerificationEvaluation, SpeakerVerificationEvaluationTaskConfig\n",
    "from notebooks.evaluation.ScoreCalibration import ScoreCalibration\n",
    "\n",
    "\n",
    "class FusedAndCalibratedSVEvaluation(SpeakerVerificationEvaluation):\n",
    "    \n",
    "    def __init__(self, train_evaluations, test_evaluations, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        self.evaluations = test_evaluations\n",
    "        self.sc = ScoreCalibration(train_evaluations)\n",
    "\n",
    "    def _prepare_evaluation(self):\n",
    "        self.sc.train()\n",
    "    \n",
    "    def _get_sv_score(self, a, b):\n",
    "        scores = [evaluation._get_sv_score(a, b) for evaluation in self.evaluations]\n",
    "        score = self.sc.predict(torch.tensor(scores).unsqueeze(0))\n",
    "        return score.detach().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_ECAPA_SSL = {k:v for k, v in MODELS_ECAPA.items() if k != \"Supervised\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vox2fusion_scores_ecapa = evaluate_sv(MODELS_ECAPA_SSL, 'embeddings_vox2f_avg.pt', trials=[\"voxceleb2_test_fusion\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_evals = [vox2fusion_scores_ecapa[model][\"evaluation\"] for model in MODELS_ECAPA_SSL]\n",
    "test_evals = [vox1o_scores_ecapa[model][\"evaluation\"] for model in MODELS_ECAPA_SSL]\n",
    "\n",
    "evaluation = FusedAndCalibratedSVEvaluation(\n",
    "    train_evaluations=train_evals,\n",
    "    test_evaluations=test_evals,\n",
    "    model=None,\n",
    "    config=test_evals[0].config,\n",
    "    task_config=SpeakerVerificationEvaluationTaskConfig(\n",
    "        trials=['voxceleb1_test_O', 'voxceleb1_test_E', 'voxceleb1_test_H']\n",
    "        # trials=['voxceleb1_test_O']\n",
    "    ),\n",
    "    device='cpu'\n",
    ")\n",
    "\n",
    "evaluation.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, model in enumerate(MODELS_ECAPA_SSL):\n",
    "    print(model, evaluation.sc.model.W.weight[0, i].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation.sc.model.W.weight, evaluation.sc.model.W.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fusion (representations-level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sslsv.evaluations.CosineSVEvaluation import CosineSVEvaluation, CosineSVEvaluationTaskConfig\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class ReprConcatenationSVEvaluation(CosineSVEvaluation):\n",
    "    \n",
    "    def __init__(self, evaluations, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        self.evaluations = evaluations\n",
    "\n",
    "    def _prepare_evaluation(self):\n",
    "        self.test_embeddings = {}\n",
    "\n",
    "        for k in self.evaluations[0].test_embeddings.keys():\n",
    "            self.test_embeddings[k] = torch.cat([\n",
    "                self.evaluations[i].test_embeddings[k]\n",
    "                for i in range(len(self.evaluations))\n",
    "            ], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_evals = [vox1o_scores_ecapa[model][\"evaluation\"] for model in MODELS_ECAPA_SSL]\n",
    "\n",
    "evaluation = ReprConcatenationSVEvaluation(\n",
    "    evaluations=test_evals,\n",
    "    model=None,\n",
    "    config=test_evals[0].config,\n",
    "    task_config=CosineSVEvaluationTaskConfig(\n",
    "        # trials=['voxceleb1_test_O', 'voxceleb1_test_E', 'voxceleb1_test_H']\n",
    "        trials=['voxceleb1_test_O']\n",
    "    ),\n",
    "    device='cpu'\n",
    ")\n",
    "\n",
    "evaluation.evaluate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = det_curve(get_models_for_visualization(vox1o_scores_ecapa))\n",
    "p += scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=MODELS_ECAPA_ORDER)\n",
    "p.save('det.pdf')\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = []\n",
    "for name, path in MODELS_ECAPA.items():\n",
    "    try:\n",
    "        with open(f'{path}/training.json', \"r\") as f:\n",
    "            train = json.load(f)\n",
    "    except:\n",
    "        continue\n",
    "    for epoch, metrics in train.items():\n",
    "        res.append({'Epoch': int(epoch), 'Model': name, **metrics})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "\n",
    "def create_plot(y, label):\n",
    "    p = (\n",
    "        ggplot(data, aes(x='Epoch', y=y, color='factor(Model)'))\n",
    "        + geom_line(size=1)\n",
    "        # + geom_point()\n",
    "        + scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=MODELS_ECAPA_ORDER)\n",
    "        + labs(x='Epoch', y=label, color='Models')\n",
    "        + theme_bw()\n",
    "        + theme(\n",
    "            figure_size=(8, 4.75),\n",
    "            text=element_text(size=14),\n",
    "            legend_position='top',\n",
    "            legend_title=element_blank(),\n",
    "            legend_key_spacing_x=15\n",
    "        )\n",
    "        + guides(color=guide_legend(nrow=1))\n",
    "    )\n",
    "    return p\n",
    "\n",
    "\n",
    "g_loss = create_plot('train/loss', 'Train loss')\n",
    "g_eer = create_plot('val/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')\n",
    "g_mindcf = create_plot('val/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')\n",
    "\n",
    "p = g_eer\n",
    "\n",
    "# p.save(\"convergence.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label-efficient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"Supervised\": {\n",
    "        100: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2/\",\n",
    "         50: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.5/\",\n",
    "         20: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.2/\",\n",
    "         10: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.1/\",\n",
    "          5: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.05/\",\n",
    "          2: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.02/\",\n",
    "          1: \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.01/\",\n",
    "    },\n",
    "    \"SSL (DINO)\": {\n",
    "        100: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-1.0/\",\n",
    "         50: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.5/\",\n",
    "         20: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.2/\",\n",
    "         10: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.1/\",\n",
    "          5: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.05/\",\n",
    "          2: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.02/\",\n",
    "          1: \"models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.01/\",\n",
    "    },\n",
    "}\n",
    "\n",
    "res = []\n",
    "for name, entry in MODELS.items():\n",
    "    for x, path in entry.items():\n",
    "        with open(f'{path}/training.json', \"r\") as f:\n",
    "            train = json.load(f)\n",
    "        res.append({'x': x, 'Model': name, **train[\"99\"]})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "\n",
    "p = (\n",
    "    ggplot(data, aes(x='x', y='val/sv_cosine/voxceleb1_test_O/eer', color='Model', group='Model'))\n",
    "    + geom_line(size=1.25)\n",
    "    + geom_point(size=3)\n",
    "    \n",
    "    + geom_segment(aes(x=100, y=2.5, xend=50, yend=2.5), size=0.75, color='#2b2b2b', arrow=arrow(type='closed', ends='last', length=0.1))\n",
    "    + annotate(\"text\", x=70, y=3.25, label='2x fewer\\nlabels', color='#2b2b2b', size=12)\n",
    "    \n",
    "    + geom_segment(aes(x=5, y=4, xend=1, yend=4), size=0.75, color='#2b2b2b', arrow=arrow(type='closed', ends='last', length=0.1))\n",
    "    + annotate(\"text\", x=2.25, y=4.75, label='5x fewer\\nlabels', color='#2b2b2b', size=12)\n",
    "\n",
    "    + scale_colour_manual(values=[\"#4a78ff\", \"#01041a\"])\n",
    "    # + scale_colour_manual(values=[\"#57d3db\", \"#2db4bd\", \"#db5f57\"])\n",
    "    + scale_x_log10(breaks=[1, 2, 5, 10, 20, 50, 100])\n",
    "    + scale_y_continuous(breaks=[1, 2, 4, 6, 8, 10])\n",
    "    + xlab(\"% of labeled data\")\n",
    "    + ylab(\"EER (%)\")\n",
    "    + theme_bw()\n",
    "    + theme(\n",
    "        figure_size=(5, 4.75),\n",
    "        text=element_text(size=14),\n",
    "        legend_title=element_blank(),\n",
    "        legend_position=\"top\",\n",
    "        legend_key_spacing_x=20\n",
    "        # legend_background=element_rect(fill='white', alpha=1.0, linetype='solid', color='#ebebeb')\n",
    "    )\n",
    ")\n",
    "\n",
    "# p.save('label_efficient.pdf')\n",
    "\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data-augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"MoCo\": {\n",
    "        100: \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "         75: \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-75/\",\n",
    "         50: \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-50/\",\n",
    "         25: \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-25/\",\n",
    "          0: \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-none/\",\n",
    "    },\n",
    "    \"DINO\": {\n",
    "        100: \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "         75: \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-75/\",\n",
    "         50: \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-50/\",\n",
    "         25: \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-25/\",\n",
    "          0: \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-none/\",\n",
    "    },\n",
    "    \"Supervised\": {\n",
    "        100: \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/\",\n",
    "         75: \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-75/\",\n",
    "         50: \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-50/\",\n",
    "         25: \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-25/\",\n",
    "          0: \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-none/\",\n",
    "    }\n",
    "}\n",
    "\n",
    "res = []\n",
    "for name, entry in MODELS.items():\n",
    "    for augprob, path in entry.items():\n",
    "        with open(f'{path}/evaluation.json', \"r\") as f:\n",
    "            eval = json.load(f)\n",
    "        res.append({'AugProb': augprob, 'Model': name, **eval})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "\n",
    "p = (\n",
    "    ggplot(data, aes(x='AugProb', y='test/sv_cosine/voxceleb1_test_O/eer', color='factor(Model)'))\n",
    "    + geom_line(size=1)\n",
    "    + geom_point(size=2)\n",
    "    + scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=[\"MoCo\", \"DINO\", \"Supervised\"])\n",
    "    + labs(x='% of data-augmentation', y='EER (%)', color='Models')\n",
    "    + theme_bw()\n",
    "    + theme(\n",
    "        figure_size=(8, 3.5),\n",
    "        text=element_text(size=14),\n",
    "        legend_position=\"top\",\n",
    "        legend_title=element_blank(),\n",
    "        legend_key_spacing_x=20\n",
    "    )\n",
    ")\n",
    "# p.save(\"data-aug.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intra/inter-speaker similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_ECAPA_PALETTE_ALPHA = {k:v + \"B3\" for k, v in MODELS_ECAPA_PALETTE.items()}\n",
    "MODELS_ECAPA_PALETTE_ALPHA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {k:f'{v}/embeddings_vox1_avg.pt' for k, v in MODELS_ECAPA.items()}\n",
    "# MODELS.update({f'{k}-sup':f'{v}/embeddings_vox1_avg.pt' for k, v in MODELS_ECAPA.items() if k == 'Supervised'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p, stats = plot_intra_class_similarity('speaker', MODELS)\n",
    "p += scale_fill_manual(values=MODELS_ECAPA_PALETTE_ALPHA, limits=MODELS_ECAPA_ORDER)\n",
    "# p.save(\"intra-speaker.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p, stats = plot_inter_class_similarity('speaker', MODELS, nb_samples=100)\n",
    "p += scale_fill_manual(values=MODELS_ECAPA_PALETTE_ALPHA, limits=MODELS_ECAPA_ORDER)\n",
    "# p.save(\"inter-speaker.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"MoCo\": {\n",
    "        \"Full\":     \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "        \"50% spk.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-half-spk/\",\n",
    "        \"50% utt.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-half-utt/\",\n",
    "        \"25% spk.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-quarter-spk/\",\n",
    "        \"25% utt.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-quarter-utt/\",\n",
    "    },\n",
    "    \"DINO\": {\n",
    "        \"Full\":     \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "        \"50% spk.\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-half-spk/\",\n",
    "        \"50% utt.\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-half-utt/\",\n",
    "        \"25% spk.\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-quarter-spk/\",\n",
    "        \"25% utt.\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-quarter-utt/\",\n",
    "    },\n",
    "    \"Supervised\": {\n",
    "        \"Full\":     \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/\",\n",
    "        \"50% spk.\": \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-half-spk/\",\n",
    "        \"50% utt.\": \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-half-utt/\",\n",
    "        \"25% spk.\": \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-quarter-spk/\",\n",
    "        \"25% utt.\": \"models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-quarter-utt/\",\n",
    "    }\n",
    "}\n",
    "\n",
    "OPTIONS = {\n",
    "    \"Full\": \"#001233\",\n",
    "    \"50% spk.\": \"#023e7d\",\n",
    "    \"50% utt.\": \"#0466c8\",\n",
    "    \"25% spk.\": \"#76c893\",\n",
    "    \"25% utt.\": \"#d9ed92\"\n",
    "}\n",
    "\n",
    "res = []\n",
    "for name, entry in MODELS.items():\n",
    "    for option, path in entry.items():\n",
    "        with open(f'{path}/evaluation.json', \"r\") as f:\n",
    "            eval = json.load(f)\n",
    "        res.append({'Option': option, 'Model': name, **eval})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "data['Option'] = pd.Categorical(data['Option'], categories=OPTIONS, ordered=True)\n",
    "\n",
    "p = (\n",
    "    ggplot(data, aes(x='factor(Model)', y='test/sv_cosine/voxceleb1_test_O/eer', fill='factor(Option)'))\n",
    "    + geom_bar(stat='identity', position='dodge', width=0.7)\n",
    "    + scale_fill_manual(values=OPTIONS, limits=list(OPTIONS.keys()))\n",
    "    # + scale_fill_brewer(type=\"seq\", palette=\"Blues\", direction=-1)\n",
    "    + scale_x_discrete(limits=[\"MoCo\", \"DINO\", \"Supervised\"])\n",
    "    + coord_cartesian(ylim=(2.0, 10.5))\n",
    "    + labs(x='', y='EER (%)', fill='Training distribution')\n",
    "    + theme_bw()\n",
    "    + theme(\n",
    "        figure_size=(8, 4.75),\n",
    "        text=element_text(size=14),\n",
    "        legend_title=element_blank(),\n",
    "        legend_position=\"top\",\n",
    "        legend_key_spacing_x=20\n",
    "    )\n",
    ")\n",
    "# p.save(\"training_distribution.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"SimCLR\": {\n",
    "        \"SSL pos. sampling\":  \"models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/\",\n",
    "        \"Supervised pos. sampling\": \"models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03_sup2/\",\n",
    "    },\n",
    "    \"MoCo\": {\n",
    "        \"SSL pos. sampling\":  \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "        \"Supervised pos. sampling\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_sup2/\",\n",
    "    },\n",
    "    \"SwAV\": {\n",
    "        \"SSL pos. sampling\":  \"models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/\",\n",
    "        \"Supervised pos. sampling\": \"models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1_sup2/\",\n",
    "    },\n",
    "    \"VICReg\": {\n",
    "        \"SSL pos. sampling\":  \"models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/\",\n",
    "        \"Supervised pos. sampling\": \"models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1_sup2/\",\n",
    "    },\n",
    "    \"DINO\": {\n",
    "        \"SSL pos. sampling\":  \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "        \"Supervised pos. sampling\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_sup2/\",\n",
    "    },\n",
    "}\n",
    "\n",
    "OPTIONS = {\n",
    "    \"SSL pos. sampling\": \"#01041a\",\n",
    "    \"Supervised pos. sampling\": \"#4a78ff\",\n",
    "}\n",
    "\n",
    "res = []\n",
    "for name, entry in MODELS.items():\n",
    "    for option, path in entry.items():\n",
    "        with open(f'{path}/nmi.json', \"r\") as f:\n",
    "            eval = json.load(f)\n",
    "        res.append({\n",
    "            'Option': option,\n",
    "            'Model': name,\n",
    "            'nmi_ratio': eval['vox1_nmi_speaker'] / eval['vox1_nmi_video']\n",
    "        })\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "data['Option'] = pd.Categorical(data['Option'], categories=OPTIONS, ordered=True)\n",
    "\n",
    "p = (\n",
    "    ggplot(data, aes(x='factor(Model)', y='nmi_ratio', fill='factor(Option)'))\n",
    "    + geom_bar(stat='identity', position='dodge', width=0.7)\n",
    "    + scale_fill_manual(values=OPTIONS, limits=[\"SSL pos. sampling\", \"Supervised pos. sampling\"])\n",
    "    + scale_x_discrete(limits=[\"SimCLR\", \"MoCo\", \"SwAV\", \"VICReg\", \"DINO\"])\n",
    "    + coord_cartesian(ylim=(0.92, 1.12))\n",
    "    + labs(x='', y='Speaker-to-Recording NMI Ratio', fill='Pos. sampling.')\n",
    "    + theme_bw()\n",
    "    + theme(\n",
    "        figure_size=(8, 4.75),\n",
    "        text=element_text(size=14),\n",
    "        legend_title=element_blank(),\n",
    "        legend_position=\"top\",\n",
    "        legend_key_spacing_x=20\n",
    "    )\n",
    ")\n",
    "# p.save(\"nmi.pdf\")\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pivoted = data.pivot(index=\"Model\", columns=\"Option\", values=\"nmi_ratio\")\n",
    "pivoted[\"relative_improvement\"] = (\n",
    "    (pivoted[\"Supervised pos. sampling\"] - pivoted[\"SSL pos. sampling\"])\n",
    "    / pivoted[\"SSL pos. sampling\"]\n",
    ")\n",
    "pivoted[\"absolute_improvement\"] = (\n",
    "    (pivoted[\"Supervised pos. sampling\"] - pivoted[\"SSL pos. sampling\"])\n",
    ")\n",
    "pivoted[\"relative_improvement\"].mean(), pivoted[\"absolute_improvement\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collapse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"Baseline\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-default\",\n",
    "    \"No negatives\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-nonegs\",\n",
    "    \"High temp.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-hightemp\",\n",
    "    \"Low temp.\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-lowtemp\",\n",
    "    \"No momentum\": \"models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-nomom\",\n",
    "}\n",
    "\n",
    "res = []\n",
    "for model, path in MODELS.items():\n",
    "    with open(f'{path}/debug.json', \"r\") as f:\n",
    "        debug = json.load(f)\n",
    "    for step, metrics in debug.items():\n",
    "        step = int(step)\n",
    "        if step > 12000:\n",
    "            break\n",
    "        if step % 100 != 0:\n",
    "            continue\n",
    "        res.append({'Step': step, 'Model': model, **metrics})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "\n",
    "def create_plot(y, label):\n",
    "    p = (\n",
    "        ggplot(data, aes(x='Step', y=y, color='factor(Model)', linetype='factor(Model)', size='factor(Model)'))\n",
    "        + geom_line()\n",
    "        + labs(x='Step', y=label, color='Models')\n",
    "        + scale_color_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": \"#01041a\",\n",
    "                \"Low temp.\": \"#3572d4\",\n",
    "                \"High temp.\": \"#3572d4\",\n",
    "                \"No negatives\": \"#e23d3d\",\n",
    "                \"No momentum\": \"#8f8f8f\"\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No negatives\", \"Low temp.\", \"High temp.\", \"No momentum\"],\n",
    "        )\n",
    "        + scale_linetype_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": \"solid\",\n",
    "                \"Low temp.\": \"solid\",\n",
    "                \"High temp.\": \"dashed\",\n",
    "                \"No negatives\": \"solid\",\n",
    "                \"No momentum\": \"solid\"\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No negatives\", \"Low temp.\", \"High temp.\", \"No momentum\"],\n",
    "        )\n",
    "        + scale_size_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": 1.0,\n",
    "                \"Low temp.\": 0.75,\n",
    "                \"High temp.\": 0.75,\n",
    "                \"No negatives\": 0.75,\n",
    "                \"No momentum\": 0.75\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No negatives\", \"Low temp.\", \"High temp.\", \"No momentum\"],\n",
    "        )\n",
    "        + theme_bw()\n",
    "        + theme(\n",
    "            figure_size=(6, 3.5),\n",
    "            text=element_text(size=14),\n",
    "            legend_title=element_blank(),\n",
    "            legend_position=\"none\",\n",
    "            legend_key_spacing_x=20,\n",
    "        )\n",
    "        + guides(color=guide_legend(nrow=2))\n",
    "    )\n",
    "    return p\n",
    "\n",
    "\n",
    "g_loss = create_plot('train/loss', 'Loss')\n",
    "g_h = create_plot('train/h', 'Contrastive Entropy')\n",
    "g_std = create_plot('train/std', 'Embeddings Std.')\n",
    "\n",
    "\n",
    "g_h.save(\"collapse_moco_h.pdf\")\n",
    "g_std.save(\"collapse_moco_std.pdf\")\n",
    "\n",
    "g_h, g_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    \"No momentum\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-nomom\",\n",
    "    \"Baseline\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-default\",\n",
    "    \"No centering\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-nocentering\",\n",
    "    \"No sharpening\": \"models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-nosharpening\",\n",
    "}\n",
    "\n",
    "res = []\n",
    "for model, path in MODELS.items():\n",
    "    with open(f'{path}/debug.json', \"r\") as f:\n",
    "        debug = json.load(f)\n",
    "    for step, metrics in debug.items():\n",
    "        step = int(step)\n",
    "        if step > 25000:\n",
    "            break\n",
    "        if step % 200 != 0:\n",
    "            continue\n",
    "        res.append({'Step': step, 'Model': model, **metrics})\n",
    "\n",
    "data = pd.DataFrame(res)\n",
    "\n",
    "order = ['No centering', 'No sharpening', 'No momentum', 'Baseline']\n",
    "rank  = {k: i for i, k in enumerate(order)}\n",
    "data  = data.sort_values('Model', key=lambda s: s.map(rank))\n",
    "\n",
    "def create_plot(y, label):\n",
    "    p = (\n",
    "        ggplot(data, aes(x='Step', y=y, color='factor(Model)', linetype='factor(Model)', size='factor(Model)'))\n",
    "        + geom_line()\n",
    "        + labs(x='Step', y=label, color='Models')\n",
    "        # + coord_cartesian(xlim=(0, 5000))\n",
    "        + scale_color_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": \"#01041a\",\n",
    "                \"No centering\": \"#e23d3d\",\n",
    "                \"No sharpening\": \"#e23d3d\",\n",
    "                \"No momentum\": \"#8f8f8f\"\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No centering\", \"No sharpening\", \"No momentum\"],\n",
    "        )\n",
    "        + scale_linetype_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": \"solid\",\n",
    "                \"No centering\": \"solid\",\n",
    "                \"No sharpening\": \"dashed\",\n",
    "                \"No momentum\": \"solid\"\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No centering\", \"No sharpening\", \"No momentum\"],\n",
    "        )\n",
    "        + scale_size_manual(\n",
    "            name='Models',\n",
    "            values={\n",
    "                \"Baseline\": 1.0,\n",
    "                \"No centering\": 0.75,\n",
    "                \"No sharpening\": 0.75,\n",
    "                \"No momentum\": 0.75\n",
    "            },\n",
    "            limits=[\"Baseline\", \"No centering\", \"No sharpening\", \"No momentum\"],\n",
    "        )\n",
    "        + theme_bw()\n",
    "        + theme(\n",
    "            figure_size=(6, 3.5),\n",
    "            text=element_text(size=14),\n",
    "            legend_title=element_blank(),\n",
    "            legend_position=\"none\",\n",
    "            legend_key_spacing_x=20,\n",
    "        )\n",
    "        + guides(color=guide_legend(nrow=2))\n",
    "    )\n",
    "    return p\n",
    "\n",
    "\n",
    "g_h = create_plot('train/teacher_h', 'Teacher Entropy')\n",
    "g_kl = create_plot('train/kl_div', 'Teacher-Student KL div.')\n",
    "g_std = create_plot('train/teacher_std', 'Embeddings Std.')\n",
    "\n",
    "g_h.save(\"collapse_dino_h.pdf\")\n",
    "g_kl.save(\"collapse_dino_kl.pdf\")\n",
    "g_std.save(\"collapse_dino_std.pdf\")\n",
    "\n",
    "g_h, g_kl, g_std"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_py-3.13.3_torch-2.7.1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
